Enet-Camvid
Pytorch Implementation of ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation trained on the CamVid Dataset
Project Repository: https://github.com/soumik12345/enet
# Imports for Model Creation
import torch
from torch.nn import functional as F
from torch.nn import (
Module, Conv2d, ReLU, PReLU, Dropout2d, AvgPool2d,
Upsample, MaxPool2d, Sequential, MaxUnpool2d,
BatchNorm2d, AdaptiveAvgPool2d, ConvTranspose2d
)
from IPython.display import SVG
Inspired by Swish Activation Function (Paper), Mish is a Self Regularized Non-Monotonic Neural Activation Function. Mish Activation Function can be mathematically represented by the following formula:
$f(x) = x * tanh(ln(1 + e^{x}))$
It can also be represented using the Softplus Activation Function: $f(x) = x * tanh( \varsigma (x))$
where, $\varsigma (x) = ln(1 + e^{x})$
class Mish(Module):
    """Mish activation: f(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x)).

    A self-regularized, non-monotonic activation (Misra, "Mish: A Self
    Regularized Non-Monotonic Neural Activation Function").
    """

    def __init__(self):
        # Bug fix: the constructor was misspelled `__init`, leaving a dead
        # method; the class only worked because Module.__init__ was inherited.
        super().__init__()

    def forward(self, input):
        # softplus(x) = ln(1 + e^x), so this is x * tanh(ln(1 + e^x)).
        return input * torch.tanh(F.softplus(input))
Since Pytorch does not explicitly have any Module for Activation unlike Tensorflow, we can easily implement it. The following Module could be modified to incorporate for any number of activation functions, each of which can be accessed with a string label.
class Activation(Module):
    """Select an activation module ('relu', 'prelu' or 'mish') by string label.

    Args:
        name: one of 'relu', 'prelu', 'mish'.

    Raises:
        ValueError: if `name` is not a recognized label.  (Previously an
        unknown label silently left `self.act` unset, producing a confusing
        AttributeError only when forward was first called.)
    """

    def __init__(self, name='relu'):
        super().__init__()
        self.name = name
        if name == 'relu':
            self.act = ReLU()
        elif name == 'prelu':
            self.act = PReLU()
        elif name == 'mish':
            self.act = Mish()
        else:
            raise ValueError("Unknown activation name: {!r}".format(name))

    def forward(self, input):
        return self.act(input)
Enet Initial Block
The Initial Block is the first block of the ENet model. It consists of 2 branches: a convolutional layer (out_channels=13, kernel_size=3, stride=2), which we call the main branch in our implementation, and a MaxPooling layer performed with non-overlapping 2x2 windows, which forms the secondary branch. We perform Batch Normalization and a non-linear activation on the concatenation of the two branches. The Initial Block has 16 output channels.
#collapse-hide
SVG('enet_initial_block_xc7itf.svg')
class InitialBlock(Module):
    """ENet initial block.

    Concatenates a strided 3x3 convolution (producing `out_channels - 3`
    feature maps) with a stride-2 max-pool of the raw input (its original
    channels), then applies batch norm and the chosen activation, so the
    concatenated result has exactly `out_channels` channels at half the
    spatial resolution.
    """

    def __init__(self, in_channels, out_channels, bias=False, activation='relu'):
        super().__init__()
        # The conv supplies all but 3 output channels; the pooled input
        # supplies the remaining 3 (assumes an RGB input).
        self.main_branch = Conv2d(
            in_channels,
            out_channels - 3,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=bias,
        )
        self.secondary_branch = MaxPool2d(3, stride=2, padding=1)
        self.batch_norm = BatchNorm2d(out_channels)
        self.activation = Activation(activation)

    def forward(self, x):
        # Run both branches on the same input and fuse along channels.
        fused = torch.cat((self.main_branch(x), self.secondary_branch(x)), 1)
        return self.activation(self.batch_norm(fused))
Regular Bottleneck Block
In case of the Regular Bottleneck Block which is the most widely used block in the ENet architecture, the secondary block has no operations. The middle convolution blocks are either 3x3 regular convolutional block or a 5x5 asymmetric convolutional block. All the convolutional blocks in the main branch have Batchnormalization and respective Activation after them. The main branch is regularized by a Dropout operation.
#collapse-hide
SVG('enet_regular_bottleneck_esc4ir.svg')
class RegularBottleneckBlock(Module):
    """Non-downsampling ENet bottleneck with an identity skip connection.

    Main branch: 1x1 squeeze -> middle conv (square, dilated, or an
    asymmetric (k,1)+(1,k) pair) -> 1x1 expand, each followed by batch norm
    and activation, then spatial dropout.  The secondary branch is the
    identity; the two are summed and passed through a final activation.
    """

    def __init__(
            self, channels, internal_ratio=4, kernel_size=3, padding=0,
            dilation=1, asymmetric=False, dropout_prob=0, bias=False,
            activation='relu'):
        super().__init__()
        squeezed = channels // internal_ratio
        # 1x1 bottleneck projection down to the squeezed width.
        self.main_conv_block_1 = Sequential(
            Conv2d(channels, squeezed, kernel_size=1, stride=1, bias=bias),
            BatchNorm2d(squeezed),
            Activation(activation),
        )
        # Middle convolution: a (k,1)+(1,k) factorized pair when asymmetric,
        # otherwise a single square (possibly dilated) convolution.
        if asymmetric:
            self.main_conv_block_2 = Sequential(
                Conv2d(
                    squeezed, squeezed, kernel_size=(kernel_size, 1),
                    stride=1, padding=(padding, 0), dilation=dilation,
                    bias=bias,
                ),
                BatchNorm2d(squeezed),
                Activation(activation),
                Conv2d(
                    squeezed, squeezed, kernel_size=(1, kernel_size),
                    stride=1, padding=(0, padding), dilation=dilation,
                    bias=bias,
                ),
                BatchNorm2d(squeezed),
                Activation(activation),
            )
        else:
            self.main_conv_block_2 = Sequential(
                Conv2d(
                    squeezed, squeezed, kernel_size=kernel_size, stride=1,
                    padding=padding, dilation=dilation, bias=bias,
                ),
                BatchNorm2d(squeezed),
                Activation(activation),
            )
        # 1x1 expansion back to the block's channel count.
        self.main_conv_block_3 = Sequential(
            Conv2d(squeezed, channels, kernel_size=1, stride=1, bias=bias),
            BatchNorm2d(channels),
            Activation(activation),
        )
        # Spatial dropout regularizes the main branch only.
        self.dropout = Dropout2d(p=dropout_prob)
        self.activation = Activation(activation)

    def forward(self, x):
        residual = x
        main = self.main_conv_block_1(x)
        main = self.main_conv_block_2(main)
        main = self.main_conv_block_3(main)
        main = self.dropout(main)
        return self.activation(main + residual)
#collapse-hide
SVG('enet_downsampling_bottleneck-1_ysayci.svg')
class DownsampleBottleneckBlock(Module):
    """Downsampling ENet bottleneck.

    Main branch: 2x2 stride-2 conv (squeeze + downsample) -> 3x3 conv ->
    1x1 expansion, each with batch norm and activation, then spatial
    dropout.  Secondary branch: 2x2 max-pool of the input, zero-padded
    along the channel axis to match the main branch.  The branches are
    summed and activated.

    Args:
        in_channels, out_channels: input / output channel counts.
        internal_ratio: squeeze factor for the internal channel width.
        return_indices: when True, forward also returns the max-pool
            indices (needed by MaxUnpool2d in the decoder).
        dropout_prob: spatial dropout probability on the main branch.
        bias: whether convolutions carry a bias term.
        activation: activation label ('relu' / 'prelu' / 'mish').
    """

    def __init__(
        self, in_channels, out_channels, internal_ratio=4,
        return_indices=False, dropout_prob=0, bias=False, activation='relu'):
        super().__init__()
        internal_channels = in_channels // internal_ratio
        self.return_indices = return_indices
        ### Main Branch ###
        # Block 1: 2x2 stride-2 conv performs the downsampling.
        self.main_conv_block_1 = Sequential(
            Conv2d(
                in_channels, internal_channels,
                kernel_size=2, stride=2, bias=bias
            ),
            BatchNorm2d(internal_channels),
            Activation(activation)
        )
        # Block 2: 3x3 conv at the squeezed width.
        self.main_conv_block_2 = Sequential(
            Conv2d(
                internal_channels, internal_channels,
                kernel_size=3, stride=1, padding=1, bias=bias
            ),
            BatchNorm2d(internal_channels),
            Activation(activation)
        )
        # Block 3: 1x1 conv expands to out_channels.
        self.main_conv_block_3 = Sequential(
            Conv2d(
                internal_channels, out_channels,
                kernel_size=1, stride=1, bias=bias
            ),
            BatchNorm2d(out_channels),
            Activation(activation)
        )
        ### Secondary Branch ###
        self.secondary_maxpool = MaxPool2d(
            2, stride=2,
            return_indices=return_indices
        )
        # Dropout Regularization
        self.dropout = Dropout2d(p=dropout_prob)
        # Activation
        self.activation = Activation(activation)

    def forward(self, x):
        # Main Branch
        main_branch = self.main_conv_block_1(x)
        main_branch = self.main_conv_block_2(main_branch)
        main_branch = self.main_conv_block_3(main_branch)
        # Bug fix: self.dropout was constructed but never applied; regularize
        # the main branch as the other bottleneck blocks do.
        main_branch = self.dropout(main_branch)
        # Secondary Branch
        if self.return_indices:
            secondary_branch, max_indices = self.secondary_maxpool(x)
        else:
            secondary_branch = self.secondary_maxpool(x)
        # Zero-pad the pooled branch's channels to match the main branch.
        # Allocating on the input's device/dtype replaces the previous
        # is_cuda/.cuda() check and works on any device.
        n, ch_main, h, w = main_branch.size()
        ch_sec = secondary_branch.size(1)
        padding = torch.zeros(
            n, ch_main - ch_sec, h, w,
            dtype=secondary_branch.dtype, device=secondary_branch.device
        )
        # Concatenate, sum with the main branch, and activate.
        secondary_branch = torch.cat((secondary_branch, padding), 1)
        output = self.activation(secondary_branch + main_branch)
        if self.return_indices:
            return output, max_indices
        return output
#collapse-hide
SVG('enet_upsampling_block_du1lmw.svg')
class UpsampleBottleneckBlock(Module):
    """Upsampling (decoder) ENet bottleneck.

    Main branch: 1x1 squeeze -> 2x2 stride-2 transposed conv (learned 2x
    upsampling) -> 1x1 expansion, then spatial dropout.  Secondary branch:
    1x1 conv + batch norm followed by max-unpooling using the indices saved
    by the matching encoder stage.  The branches are summed and activated.
    """

    def __init__(
            self, in_channels, out_channels,
            internal_ratio=4, dropout_prob=0,
            bias=False, activation='relu'):
        super().__init__()
        squeezed = in_channels // internal_ratio
        # Main branch: 1x1 squeeze.
        self.main_branch_conv_1 = Sequential(
            Conv2d(in_channels, squeezed, kernel_size=1, bias=bias),
            BatchNorm2d(squeezed),
            Activation(activation),
        )
        # Main branch: transposed conv kept outside a Sequential because
        # forward must feed output_size through to it.
        self.main_branch_transpose_conv_2 = ConvTranspose2d(
            squeezed, squeezed, kernel_size=2, stride=2, bias=bias
        )
        self.main_branch_bn_2 = BatchNorm2d(squeezed)
        self.main_branch_act_2 = Activation(activation)
        # Main branch: 1x1 expansion.
        self.main_branch_conv_3 = Sequential(
            Conv2d(squeezed, out_channels, kernel_size=1, bias=bias),
            BatchNorm2d(out_channels),
            Activation(activation),
        )
        # Secondary branch: channel projection, then unpooling.
        self.secondary_conv = Sequential(
            Conv2d(in_channels, out_channels, kernel_size=1, bias=bias),
            BatchNorm2d(out_channels),
        )
        self.secondary_unpool = MaxUnpool2d(kernel_size=2)
        # Spatial dropout regularizes the main branch.
        self.dropout = Dropout2d(p=dropout_prob)
        self.activation = Activation(activation)

    def forward(self, x, max_indices, output_size):
        # Main branch.
        main = self.main_branch_conv_1(x)
        main = self.main_branch_transpose_conv_2(main, output_size=output_size)
        main = self.main_branch_act_2(self.main_branch_bn_2(main))
        main = self.dropout(self.main_branch_conv_3(main))
        # Secondary branch: unpool to the encoder's pre-pool size.
        skip = self.secondary_unpool(
            self.secondary_conv(x), max_indices, output_size=output_size
        )
        return self.activation(main + skip)
Building Enet
The overall architecture of Enet is summarized in the following table. The whole architecture is divided into 6 parts or stages. Stage 0 consists of the Initial Block only. Stages 1-3 make up the encoder part of the network which downsamples the input. Stages 4-5 makes up the decoder, which upsamples the input to create the output.
| Name | Type | Output Size |
|---|---|---|
| Initial | Initial Block | 16x256x256 |
| Bottleneck_1 | Downsampling | 64x128x128 |
| RegularBottleneck_1_1 | Regular | 64x128x128 |
| RegularBottleneck_1_2 | Regular | 64x128x128 |
| RegularBottleneck_1_3 | Regular | 64x128x128 |
| RegularBottleneck_1_4 | Regular | 64x128x128 |
| Bottleneck_2 | Downsampling | 128x64x64 |
| RegularBottleneck_2_1 | Regular | 128x64x64 |
| RegularBottleneck_2_2 | Dilated 2 | 128x64x64 |
| RegularBottleneck_2_3 | Asymmetric 5 | 128x64x64 |
| RegularBottleneck_2_4 | Dilated 4 | 128x64x64 |
| RegularBottleneck_2_5 | Regular | 128x64x64 |
| RegularBottleneck_2_6 | Dilated 8 | 128x64x64 |
| RegularBottleneck_2_7 | Asymmetric 5 | 128x64x64 |
| RegularBottleneck_2_8 | Dilated 16 | 128x64x64 |
| RegularBottleneck_3 | Regular | 128x64x64 |
| RegularBottleneck_3_1 | Dilated 2 | 128x64x64 |
| RegularBottleneck_3_2 | Asymmetric 5 | 128x64x64 |
| RegularBottleneck_3_3 | Dilated 4 | 128x64x64 |
| RegularBottleneck_3_4 | Regular | 128x64x64 |
| RegularBottleneck_3_5 | Dilated 8 | 128x64x64 |
| RegularBottleneck_3_6 | Asymmetric 5 | 128x64x64 |
| RegularBottleneck_3_7 | Dilated 16 | 128x64x64 |
| Bottleneck_4 | Upsampling | 64x128x128 |
| Bottleneck_4_1 | Regular | 64x128x128 |
| Bottleneck_4_2 | Regular | 64x128x128 |
| Bottleneck_5 | Upsampling | 16x256x256 |
| Bottleneck_5_1 | Regular | 16x256x256 |
| Transposed_Conv | Transposed Convolution | Cx512x512 |
class Enet(Module):
    """Full ENet for semantic segmentation.

    Pipeline: initial block, encoder stages 1-3 (downsampling + regular /
    dilated / asymmetric bottlenecks), decoder stages 4-5 (upsampling +
    regular bottlenecks), and a final transposed convolution producing
    `num_classes` score maps at the input resolution.

    Args:
        num_classes: number of output segmentation classes.
        encoder_activation: activation label used by the encoder blocks.
        decoder_activation: activation label used by the decoder blocks.
    """

    def __init__(self, num_classes, encoder_activation='mish', decoder_activation='relu'):
        super().__init__()
        # Initial Block: 3 input channels (RGB) -> 16 feature maps at half size.
        self.initial_block = InitialBlock(3, 16, activation=encoder_activation)
        ### Encoding Stages ###
        # Stage 1: downsample 16 -> 64 channels, then 4 regular bottlenecks.
        self.down_bottleneck_1 = DownsampleBottleneckBlock(
            16, 64, return_indices=True,
            dropout_prob=0.01, activation=encoder_activation
        )
        self.bottleneck_1_1 = RegularBottleneckBlock(
            64, padding=1, dropout_prob=0.01,
            activation=encoder_activation
        )
        self.bottleneck_1_2 = RegularBottleneckBlock(
            64, padding=1, dropout_prob=0.01,
            activation=encoder_activation
        )
        self.bottleneck_1_3 = RegularBottleneckBlock(
            64, padding=1, dropout_prob=0.01,
            activation=encoder_activation
        )
        self.bottleneck_1_4 = RegularBottleneckBlock(
            64, padding=1, dropout_prob=0.01,
            activation=encoder_activation
        )
        # Stage 2: downsample 64 -> 128, then alternating regular / dilated /
        # asymmetric bottlenecks (dilations 2, 4, 8, 16 widen the receptive field).
        self.down_bottleneck_2 = DownsampleBottleneckBlock(
            64, 128, return_indices=True,
            dropout_prob=0.1, activation=encoder_activation
        )
        self.bottleneck_2_1 = RegularBottleneckBlock(
            128, padding=1, dropout_prob=0.1,
            activation=encoder_activation
        )
        self.bottleneck_2_2 = RegularBottleneckBlock(
            128, dilation=2,
            padding=2, dropout_prob=0.1,
            activation=encoder_activation
        )
        self.bottleneck_2_3 = RegularBottleneckBlock(
            128, kernel_size=5, padding=2,
            asymmetric=True, dropout_prob=0.1,
            activation=encoder_activation
        )
        self.bottleneck_2_4 = RegularBottleneckBlock(
            128, dilation=4, padding=4,
            dropout_prob=0.1, activation=encoder_activation
        )
        self.bottleneck_2_5 = RegularBottleneckBlock(
            128, padding=1, dropout_prob=0.1,
            activation=encoder_activation
        )
        self.bottleneck_2_6 = RegularBottleneckBlock(
            128, dilation=8, padding=8,
            dropout_prob=0.1, activation=encoder_activation
        )
        self.bottleneck_2_7 = RegularBottleneckBlock(
            128, kernel_size=5, asymmetric=True,
            padding=2, dropout_prob=0.1,
            activation=encoder_activation
        )
        self.bottleneck_2_8 = RegularBottleneckBlock(
            128, dilation=16, padding=16,
            dropout_prob=0.1, activation=encoder_activation
        )
        # Stage 3: same pattern as stage 2 but without the downsampling block.
        self.regular_bottleneck_3 = RegularBottleneckBlock(
            128, padding=1, dropout_prob=0.1,
            activation=encoder_activation
        )
        self.bottleneck_3_1 = RegularBottleneckBlock(
            128, dilation=2, padding=2,
            dropout_prob=0.1, activation=encoder_activation
        )
        self.bottleneck_3_2 = RegularBottleneckBlock(
            128, kernel_size=5, padding=2,
            asymmetric=True, dropout_prob=0.1,
            activation=encoder_activation
        )
        self.bottleneck_3_3 = RegularBottleneckBlock(
            128, dilation=4, padding=4,
            dropout_prob=0.1, activation=encoder_activation
        )
        self.bottleneck_3_4 = RegularBottleneckBlock(
            128, padding=1, dropout_prob=0.1,
            activation=encoder_activation
        )
        self.bottleneck_3_5 = RegularBottleneckBlock(
            128, dilation=8, padding=8,
            dropout_prob=0.1, activation=encoder_activation
        )
        self.bottleneck_3_6 = RegularBottleneckBlock(
            128, kernel_size=5, asymmetric=True,
            padding=2, dropout_prob=0.1,
            activation=encoder_activation
        )
        self.bottleneck_3_7 = RegularBottleneckBlock(
            128, dilation=16, padding=16,
            dropout_prob=0.1, activation=encoder_activation
        )
        # Stage 4 (decoder): upsample 128 -> 64, then 2 regular bottlenecks.
        self.upsample_4 = UpsampleBottleneckBlock(
            128, 64, dropout_prob=0.1,
            activation=decoder_activation
        )
        self.bottleneck_4_1 = RegularBottleneckBlock(
            64, padding=1, dropout_prob=0.1,
            activation=decoder_activation
        )
        self.bottleneck_4_2 = RegularBottleneckBlock(
            64, padding=1, dropout_prob=0.1,
            activation=decoder_activation
        )
        # Stage 5 (decoder): upsample 64 -> 16, one regular bottleneck, then a
        # final transposed conv back to full input resolution and num_classes maps.
        self.upsample_5 = UpsampleBottleneckBlock(
            64, 16, dropout_prob=0.1,
            activation=decoder_activation
        )
        self.bottleneck_5 = RegularBottleneckBlock(
            16, padding=1, dropout_prob=0.1,
            activation=decoder_activation
        )
        self.transposed_conv = ConvTranspose2d(
            16, num_classes, kernel_size=3,
            stride=2, padding=1, output_padding=1, bias=False
        )

    def forward(self, x):
        """Run the network; x is (N, 3, H, W), output (N, num_classes, H, W)."""
        # Initial Block
        # NOTE(review): input_size is recorded but never used afterwards.
        input_size = x.size()
        x = self.initial_block(x)
        # Stage 1 — record the pre-pool size and max-pool indices so the
        # matching decoder stage (upsample_5) can unpool to the same shape.
        input_size_1 = x.size()
        x, max_indices_1 = self.down_bottleneck_1(x)
        x = self.bottleneck_1_1(x)
        x = self.bottleneck_1_2(x)
        x = self.bottleneck_1_3(x)
        x = self.bottleneck_1_4(x)
        # Stage 2 — size/indices saved for upsample_4.
        input_size_2 = x.size()
        x, max_indices_2 = self.down_bottleneck_2(x)
        x = self.bottleneck_2_1(x)
        x = self.bottleneck_2_2(x)
        x = self.bottleneck_2_3(x)
        x = self.bottleneck_2_4(x)
        x = self.bottleneck_2_5(x)
        x = self.bottleneck_2_6(x)
        x = self.bottleneck_2_7(x)
        x = self.bottleneck_2_8(x)
        # Stage 3 — no downsampling; spatial size unchanged.
        x = self.regular_bottleneck_3(x)
        x = self.bottleneck_3_1(x)
        x = self.bottleneck_3_2(x)
        x = self.bottleneck_3_3(x)
        x = self.bottleneck_3_4(x)
        x = self.bottleneck_3_5(x)
        x = self.bottleneck_3_6(x)
        x = self.bottleneck_3_7(x)
        # Stage 4 — unpool with the stage-2 indices back to stage-2's input size.
        x = self.upsample_4(x, max_indices_2, output_size=input_size_2)
        x = self.bottleneck_4_1(x)
        x = self.bottleneck_4_2(x)
        # Stage 5 — unpool with the stage-1 indices, then project to class scores.
        x = self.upsample_5(x, max_indices_1, output_size=input_size_1)
        x = self.bottleneck_5(x)
        x = self.transposed_conv(x)
        return x
# Pick the compute device and move the model onto it.
# Bug fix: torch.cuda.get_device_name(0) raises on CPU-only machines, so only
# query the GPU name when CUDA is actually available.
if torch.cuda.is_available():
    print('GPU:', torch.cuda.get_device_name(0))
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
enet = Enet(12, encoder_activation='prelu', decoder_activation='relu')
enet = enet.to(device)
print(enet)
!mkdir camvid
%cd camvid
!wget https://www.dropbox.com/s/ej1gx48bxqbtwd2/CamVid.zip?dl=0 -O CamVid.zip
!unzip -qq CamVid.zip
!rm CamVid.zip
%cd ..
import numpy as np
import torch, os
from glob import glob
from time import time
from tqdm import tqdm
from PIL import Image
from torch.nn import functional as F
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, ToPILImage
class CamVidDataset(Dataset):
    """Pairs CamVid image files with their annotation-mask files.

    `images` and `labels` are parallel lists of file paths; index i of one
    corresponds to index i of the other.
    """

    def __init__(self, images, labels, height, width):
        self.images = images
        self.labels = labels
        # NOTE(review): height/width are stored but never used — no resizing
        # is performed, so source files must already have the expected size.
        self.height = height
        self.width = width

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # Image file -> (C, H, W) tensor (channel-first for the network).
        pixels = np.array(Image.open(self.images[index]))
        x = torch.tensor(pixels).permute(2, 0, 1)
        # Mask file -> (H, W) tensor of integer class labels.
        mask = np.array(Image.open(self.labels[index]))
        y = torch.tensor(mask)
        return x, y
Get the image file lists
# Collect sorted file paths for each split; sorting keeps every image
# aligned with its annotation mask (matching filenames in the two folders).
train_images = sorted(glob('./camvid/train/*'))
train_labels = sorted(glob('./camvid/trainannot/*'))
val_images = sorted(glob('./camvid/val/*'))
val_labels = sorted(glob('./camvid/valannot/*'))
test_images = sorted(glob('./camvid/test/*'))
test_labels = sorted(glob('./camvid/testannot/*'))
# Number of samples per generated batch.
batch_size = 10
Define the CamVidDataset Objects
# Dataset objects for the three splits; 512x512 is the expected frame size
# (note: CamVidDataset stores but does not enforce these dimensions).
train_dataset = CamVidDataset(train_images, train_labels, 512, 512)
val_dataset = CamVidDataset(val_images, val_labels, 512, 512)
test_dataset = CamVidDataset(test_images, test_labels, 512, 512)
Now we would create the DataLoader objects which would generate data from the dataset objects. The arguments that we would set here are:
-
batch_size: this denotes the number of samples contained in each generated batch. -
shuffle: If set to True, we will get a new order of exploration at each pass (or just keep a linear exploration scheme otherwise). Shuffling the order in which examples are fed to the classifier is helpful so that batches between epochs do not look alike. Doing so will eventually make our model more robust. -
num_workers: this denotes the number of processes that generate batches in parallel. A high enough number of workers assures that CPU computations are efficiently managed, i.e. that the bottleneck is indeed the neural network's forward and backward operations on the GPU (and not data generation).
# DataLoaders: shuffled batches, 4 worker processes generating batches in parallel.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
The decode_segmap function accepts an image of shape (H, W) and a color dictionary denoting the BGR color codes to various objects in order for us to visualize the segmentation Masks.
def decode_segmap(image, color_dict):
    """Map an (H, W) array of class labels to an (H, W, 3) uint8 image.

    `color_dict` holds color triples under keys 'obj0'..'obj11'.  Each output
    pixel receives its label's triple with the channel order reversed
    (triple[2], triple[1], triple[0]); pixels whose label is outside 0..11
    stay black.
    """
    palette = np.array(
        [color_dict['obj%d' % i] for i in range(12)], dtype=np.uint8
    )
    rgb = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.uint8)
    for label in range(12):
        # Reversed triple: output channels are (color[2], color[1], color[0]).
        rgb[image == label] = palette[label, ::-1]
    return rgb
The predict_rgb function takes the model(enet in our case), a tensor denoting a single image in the form (1, C, H, W) and the color_dict and gives us the visualizable prediction
def predict_rgb(model, tensor, color_dict):
    """Run `model` on a (1, C, H, W) tensor and return a viewable mask image.

    Takes the per-pixel argmax over the class dimension and colorizes the
    resulting (H, W) label map with decode_segmap.
    """
    with torch.no_grad():
        logits = model(tensor.float()).squeeze(0)
    labels = logits.data.max(0)[1].cpu().numpy()
    return decode_segmap(labels, color_dict)
The color_dict is a dictionary where each object is mapped to its respective color code.
# Color triples (consumed by decode_segmap, which reverses their channel
# order) for each of the 12 CamVid classes.
color_dict = {
    'obj0' : [255, 0, 0], # Sky
    'obj1' : [0, 51, 204], # Building
    'obj2' : [0, 255, 255], # Posts
    'obj3' : [153, 102, 102], # Road
    'obj4' : [51, 0, 102], # Pavement
    'obj5' : [0, 255, 0], # Trees
    'obj6' : [102, 153, 153], # Signs
    'obj7' : [204, 0, 102], # Fence
    'obj8' : [102, 0, 0], # Car
    'obj9' : [0, 153, 102], # Pedestrian
    'obj10' : [255, 255, 255], # Cyclist
    'obj11' : [0, 0, 0] # Bicycles
}
Let us generate a batch from the train dataloader and visualize them along with their prediction using an untrained Enet.
# Draw one shuffled batch and show 4 rows of (image, ground truth,
# untrained-model prediction) side by side.
x_batch, y_batch = next(iter(train_loader))
x_batch.shape, y_batch.shape
fig, axes = plt.subplots(nrows = 4, ncols = 3, figsize = (16, 16))
plt.setp(axes.flat, xticks = [], yticks = [])
c = 1  # sample index within the batch; advances once per grid row
for i, ax in enumerate(axes.flat):
    if i % 3 == 0:
        ax.imshow(ToPILImage()(x_batch[c]))
        ax.set_xlabel('Image_' + str(c))
    elif i % 3 == 1:
        ax.imshow(decode_segmap(y_batch[c], color_dict))
        ax.set_xlabel('Ground_Truth_' + str(c))
    elif i % 3 == 2:
        ax.imshow(predict_rgb(enet, x_batch[c].unsqueeze(0).to(device), color_dict))
        ax.set_xlabel('Predicted_Mask_' + str(c))
        c += 1
plt.show()
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
The authors make use of a custom class weighing scheme defined as $w_{class} = \frac{1}{ln(c + p_{class})}$, where c is an additional hyper-parameter set to 1.02. The advantage of this weighing strategy is that in contrast to the inverse class probability weighing strategy, the weights are bounded as the probability approaches 0.
def get_class_weights(loader, num_classes, c=1.02):
    """Compute ENet class weights: w_class = 1 / ln(c + p_class).

    Bug fix: the original sampled a single (shuffled) batch with
    next(iter(loader)), so the weights varied from run to run and did not
    reflect the dataset.  The whole loader is now scanned so p_class is the
    true pixel frequency of each class.

    Args:
        loader: iterable of (image_batch, label_batch) pairs.
        num_classes: number of segmentation classes.
        c: stabilizing constant from the ENet paper (default 1.02), which
            keeps the weights bounded as p_class approaches 0.

    Returns:
        np.ndarray of shape (num_classes,) with one weight per class.
    """
    counts = np.zeros(num_classes, dtype=np.int64)
    total = 0
    for _, y in loader:
        y_flat = np.asarray(y).flatten()
        counts += np.bincount(y_flat, minlength=num_classes)
        total += y_flat.size
    p_class = counts / total
    return 1 / (np.log(c + p_class))
Now, we will set up the Criterion and Optimizer. The learning rate is set to 5e-4 with a weight decay of 2e-4 as mentioned in the paper.
# Per-class weights via the ENet weighing scheme (12 CamVid classes),
# applied through the cross-entropy criterion.
class_weights = get_class_weights(train_loader, 12)
criterion = CrossEntropyLoss(
    weight=torch.FloatTensor(class_weights).to(device)
)
# Adam with the learning rate (5e-4) and weight decay (2e-4) from the paper.
optimizer = Adam(
    enet.parameters(),
    lr=5e-4,
    weight_decay=2e-4
)
Next, we implement the training procedure:
- We first loop over the Main Training Loop for a particular number of epochs.
- For each epoch, we loop over the dataset for a particular number of steps which is equal to
length of dataset // batch_size. This is to ensure that the model gets a chance to look at most of the images in a single epoch. - In PyTorch, we need to set the gradients to zero before starting backpropagation because PyTorch accumulates the gradients on subsequent backward passes.
- Perform Backpropagation.
- Store the training loss.
- Log the training results (optional).
- Perform Validation using the validation dataloader.
- Log the validation results (optional).
- Save the model states and results after several epochs (optional).
def train(
        model, train_dataloader, val_dataloader,
        device, criterion, optimizer, train_step_size, val_step_size,
        visualize_every, save_every, save_location, save_prefix, epochs):
    """Train `model`, validating, visualizing and checkpointing periodically.

    Args:
        model: network to optimize (modified in place).
        train_dataloader / val_dataloader: batch iterables of (x, y) pairs.
        device: torch device the batches are moved to.
        criterion: loss applied to (logits, targets) in both phases.
        optimizer: optimizer stepping `model`'s parameters.
        train_step_size / val_step_size: batches consumed per epoch per phase.
        visualize_every / save_every: epoch periods for plots / checkpoints.
        save_location / save_prefix: checkpoint directory and filename prefix.
        epochs: total number of epochs.

    Returns:
        (train_loss_history, val_loss_history): per-epoch mean losses.
    """
    # Create the checkpoint directory if needed (no-op when it exists);
    # replaces the old bare `except: pass`.
    os.makedirs(save_location, exist_ok=True)
    train_loss_history, val_loss_history = [], []
    for epoch in range(1, epochs + 1):
        print('Epoch {}\n'.format(epoch))
        # --- Training ---
        start = time()
        train_loss = 0
        model.train()
        # Bug fix: the original called next(iter(dataloader)) on every step,
        # which restarts the loader (and its worker processes) each time and
        # re-samples batches with replacement.  Advance one iterator instead.
        train_iterator = iter(train_dataloader)
        for step in tqdm(range(train_step_size)):
            x_batch, y_batch = next(train_iterator)
            x_batch = x_batch.squeeze().to(device)
            y_batch = y_batch.squeeze().to(device)
            # Gradients accumulate across backward passes; clear them first.
            optimizer.zero_grad()
            out = model(x_batch.float())
            loss = criterion(out, y_batch.long())
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss_history.append(train_loss / train_step_size)
        print('\nTraining Loss: {}'.format(train_loss_history[-1]))
        print('Training Time: {} seconds'.format(time() - start))
        # --- Validation ---
        val_loss = 0
        model.eval()
        val_iterator = iter(val_dataloader)
        # Bug fix: validation now uses the same criterion as training (the
        # old code reported the mean label difference, which is not a loss
        # and could be negative) and runs under no_grad so no graph is built.
        with torch.no_grad():
            for step in tqdm(range(val_step_size)):
                x_val, y_val = next(val_iterator)
                x_val = x_val.squeeze().to(device)
                y_val = y_val.squeeze().to(device)
                out = model(x_val.float())
                val_loss += criterion(out, y_val.long()).item()
        val_loss_history.append(val_loss / val_step_size)
        print('\nValidation Loss: {}'.format(val_loss_history[-1]))
        # --- Visualization ---
        if epoch % visualize_every == 0:
            # Bug fix: sample from the train_dataloader argument (the
            # original read the global `train_loader`) and predict with
            # `model` rather than the global `enet`.
            x_batch, y_batch = next(iter(train_dataloader))
            fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(16, 16))
            plt.setp(axes.flat, xticks=[], yticks=[])
            c = 1  # sample index within the batch; advances once per grid row
            for i, ax in enumerate(axes.flat):
                if i % 3 == 0:
                    ax.imshow(ToPILImage()(x_batch[c]))
                    ax.set_xlabel('Image_' + str(c))
                elif i % 3 == 1:
                    ax.imshow(decode_segmap(y_batch[c], color_dict))
                    ax.set_xlabel('Ground_Truth_' + str(c))
                elif i % 3 == 2:
                    ax.imshow(predict_rgb(model, x_batch[c].unsqueeze(0).to(device), color_dict))
                    ax.set_xlabel('Predicted_Mask_' + str(c))
                    c += 1
            plt.show()
        # --- Checkpoints ---
        if epoch % save_every == 0:
            checkpoint = {
                'epoch' : epoch,
                'train_loss' : train_loss,
                'val_loss' : val_loss,
                'state_dict' : model.state_dict()
            }
            torch.save(
                checkpoint,
                '{}/{}-{}-{}-{}.pth'.format(
                    save_location, save_prefix,
                    epoch, train_loss, val_loss
                )
            )
            print('Checkpoint saved')
    print(
        '\nTraining Done.\nTraining Mean Loss: {:6f}\nValidation Mean Loss: {:6f}'.format(
            sum(train_loss_history) / epochs,
            sum(val_loss_history) / epochs
        )
    )
    return train_loss_history, val_loss_history
# Train for 100 epochs; visualize predictions and save a checkpoint every
# 5 epochs into ./checkpoints with the 'enet-model' filename prefix.
# Step counts cover roughly one pass over each split per epoch.
train_loss_history, val_loss_history = train(
    enet, train_loader, val_loader,
    device, criterion, optimizer,
    len(train_images) // batch_size,
    len(val_images) // batch_size, 5,
    5, './checkpoints', 'enet-model', 100
)
Now, let us visualize the results...
# Plot the loss curves: both histories together, then each on its own axes.
plt.plot(train_loss_history, color = 'b', label = 'Training Loss')
plt.plot(val_loss_history, color = 'r', label = 'Validation Loss')
plt.legend()
plt.show()
plt.plot(train_loss_history, color = 'b', label = 'Training Loss')
plt.legend()
plt.show()
plt.plot(val_loss_history, color = 'r', label = 'Validation Loss')
plt.legend()
plt.show()
We will be predicting with the weights at epoch 65 where both training and validation loss seems to be stable. This is done in order to avoid overfitting.
# Restore the epoch-65 weights.  map_location=device lets the checkpoint load
# even when it was saved on a GPU and we are now on a different device.
state_dict = torch.load('./checkpoints/enet-model-65-14.726004391908646--3.9436190128326416.pth', map_location=device)['state_dict']
enet.load_state_dict(state_dict)
Prediction on Training Data
# Predictions on a training batch: 4 rows of (image, ground truth, prediction).
x_batch, y_batch = next(iter(train_loader))
fig, axes = plt.subplots(nrows = 4, ncols = 3, figsize = (16, 16))
plt.setp(axes.flat, xticks = [], yticks = [])
c = 1  # sample index within the batch; advances once per grid row
for i, ax in enumerate(axes.flat):
    if i % 3 == 0:
        ax.imshow(ToPILImage()(x_batch[c]))
        ax.set_xlabel('Image_' + str(c))
    elif i % 3 == 1:
        ax.imshow(decode_segmap(y_batch[c], color_dict))
        ax.set_xlabel('Ground_Truth_' + str(c))
    elif i % 3 == 2:
        ax.imshow(predict_rgb(enet, x_batch[c].unsqueeze(0).to(device), color_dict))
        ax.set_xlabel('Predicted_Mask_' + str(c))
        c += 1
plt.show()
Prediction on Validation Data
# Predictions on a validation batch: same (image, ground truth, prediction) grid.
x_batch, y_batch = next(iter(val_loader))
fig, axes = plt.subplots(nrows = 4, ncols = 3, figsize = (16, 16))
plt.setp(axes.flat, xticks = [], yticks = [])
c = 1  # sample index within the batch; advances once per grid row
for i, ax in enumerate(axes.flat):
    if i % 3 == 0:
        ax.imshow(ToPILImage()(x_batch[c]))
        ax.set_xlabel('Image_' + str(c))
    elif i % 3 == 1:
        ax.imshow(decode_segmap(y_batch[c], color_dict))
        ax.set_xlabel('Ground_Truth_' + str(c))
    elif i % 3 == 2:
        ax.imshow(predict_rgb(enet, x_batch[c].unsqueeze(0).to(device), color_dict))
        ax.set_xlabel('Predicted_Mask_' + str(c))
        c += 1
plt.show()